# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from gpu.host import DeviceContext
from layout import LayoutTensor, Layout, RuntimeLayout, UNKNOWN_VALUE
from memory import LegacyUnsafePointer as UnsafePointer
from nn.gather_scatter import _gather_nd_impl, gather_nd_shape

from utils import IndexList


def execute_gather_nd_test[
    data_type: DType,
    indices_type: DType,
    data_rank: Int,
    indices_rank: Int,
    batch_dims: Int,
](
    data_host_ptr: UnsafePointer[Scalar[data_type]],
    data_shape: IndexList[data_rank],
    indices_host_ptr: UnsafePointer[Scalar[indices_type]],
    indices_shape: IndexList[indices_rank],
    ctx: DeviceContext,
):
    """Runs the GPU `gather_nd` implementation on host-provided inputs.

    The host data and indices are copied into freshly created device
    buffers, the output shape is derived with `gather_nd_shape`, and
    `_gather_nd_impl` is launched with `target="gpu"`. The context is then
    synchronized. The output is never copied back to the host or compared
    against a reference; the function only drives the kernel.

    Parameters:
        data_type: Element type of the data tensor (and of the output).
        indices_type: Element type of the indices tensor.
        data_rank: Rank of the data tensor.
        indices_rank: Rank of the indices tensor.
        batch_dims: Number of batch dimensions forwarded to the kernel.

    Args:
        data_host_ptr: Host pointer to the row-major data elements.
        data_shape: Shape of the data tensor.
        indices_host_ptr: Host pointer to the row-major index elements.
        indices_shape: Shape of the indices tensor.
        ctx: Device context used for allocation, copies and the launch.
    """
    # Compile-time row-major layouts for the three tensors involved.
    comptime data_row_major = Layout.row_major[data_rank]()
    comptime indices_row_major = Layout.row_major[indices_rank]()
    # NOTE(review): the output rank is hard-coded to 1. That matches the one
    # caller in this file; other input shapes would presumably need a
    # different rank -- confirm before reusing this helper.
    comptime out_rank = 1
    comptime out_row_major = Layout.row_major[out_rank]()

    # Host-side views; they are only consumed by the shape computation.
    var host_data = LayoutTensor[data_type, data_row_major](
        data_host_ptr,
        RuntimeLayout[data_row_major].row_major(data_shape),
    )
    var host_indices = LayoutTensor[indices_type, indices_row_major](
        indices_host_ptr,
        RuntimeLayout[indices_row_major].row_major(indices_shape),
    )

    # Element counts are the products of the respective shape extents.
    var num_data_elems = 1
    for axis in range(data_rank):
        num_data_elems *= data_shape[axis]
    var num_index_elems = 1
    for axis in range(indices_rank):
        num_index_elems *= indices_shape[axis]

    # Device storage for the inputs.
    var data_buf = ctx.enqueue_create_buffer[data_type](num_data_elems)
    var indices_buf = ctx.enqueue_create_buffer[indices_type](num_index_elems)

    # Derive the output shape on the host, then size the output buffer.
    var out_shape = gather_nd_shape[
        out_rank,
        data_type,
        indices_type,
        batch_dims,
    ](
        host_data,
        host_indices,
    )

    var num_out_elems = 1
    for axis in range(out_shape.size):
        num_out_elems *= out_shape[axis]

    var out_buf = ctx.enqueue_create_buffer[data_type](num_out_elems)

    # Upload the inputs.
    ctx.enqueue_copy(data_buf, data_host_ptr)
    ctx.enqueue_copy(indices_buf, indices_host_ptr)

    # Device-side views over the raw buffer pointers.
    var dev_data = LayoutTensor[data_type, data_row_major, MutAnyOrigin](
        data_buf.unsafe_ptr(),
        RuntimeLayout[data_row_major].row_major(data_shape),
    )
    var dev_indices = LayoutTensor[
        indices_type, indices_row_major, MutAnyOrigin
    ](
        indices_buf.unsafe_ptr(),
        RuntimeLayout[indices_row_major].row_major(indices_shape),
    )
    var dev_out = LayoutTensor[data_type, out_row_major, MutAnyOrigin](
        out_buf.unsafe_ptr(),
        RuntimeLayout[out_row_major].row_major(out_shape),
    )

    # Launch the kernel.
    _gather_nd_impl[batch_dims, target="gpu"](
        dev_data,
        dev_indices,
        dev_out,
        ctx,
    )
    # Give the kernel an opportunity to raise the error before finishing the test.
    ctx.synchronize()

    # The device tensors hold only raw pointers, so the buffers are consumed
    # explicitly here, after the synchronize, rather than at their last
    # syntactic use above.
    _ = data_buf^
    _ = indices_buf^
    _ = out_buf^


fn test_gather_nd_oob(ctx: DeviceContext) raises:
    """Drives the GPU `gather_nd` helper with an out-of-bounds index.

    A 2x2 `int32` data tensor holding 0..3 is paired with a 2x2 `int64`
    indices tensor whose last entry is 100, far beyond the data tensor's
    extent of 2. Both are handed to `execute_gather_nd_test` with zero
    batch dimensions.

    Args:
        ctx: Device context on which the kernel is executed.

    Raises:
        Any error propagated from `execute_gather_nd_test` or from
        synchronizing the context.
    """
    # Example 1
    comptime batch_dims = 0
    comptime data_type = DType.int32
    comptime indices_type = DType.int64
    comptime data_rank = 2
    comptime indices_rank = 2
    comptime data_row_major = Layout.row_major[data_rank]()
    comptime indices_row_major = Layout.row_major[indices_rank]()

    # Data: [[0, 1], [2, 3]].
    var data_dims = IndexList[data_rank](2, 2)
    var data_count = 4
    var data_ptr = UnsafePointer[Scalar[data_type]].alloc(data_count)
    var data_view = LayoutTensor[data_type, data_row_major](
        data_ptr,
        RuntimeLayout[data_row_major].row_major(data_dims),
    )

    data_view[0, 0] = 0
    data_view[0, 1] = 1
    data_view[1, 0] = 2
    data_view[1, 1] = 3

    # Indices: [[0, 0], [1, 100]].
    var indices_dims = IndexList[indices_rank](2, 2)
    var indices_count = 4
    var indices_ptr = UnsafePointer[Scalar[indices_type]].alloc(
        indices_count
    )
    var indices_view = LayoutTensor[indices_type, indices_row_major](
        indices_ptr,
        RuntimeLayout[indices_row_major].row_major(indices_dims),
    )

    indices_view[0, 0] = 0
    indices_view[0, 1] = 0
    indices_view[1, 0] = 1
    indices_view[1, 1] = 100  # wildly out of bounds

    execute_gather_nd_test[
        data_type, indices_type, data_rank, indices_rank, batch_dims
    ](data_ptr, data_dims, indices_ptr, indices_dims, ctx)
    ctx.synchronize()

    # Release the host allocations made above.
    data_ptr.free()
    indices_ptr.free()


def main():
    """Creates a device context and runs the out-of-bounds gather test.

    The directive comment in the body uses FileCheck syntax; presumably an
    external test runner matches it against the program's output.
    """
    with DeviceContext() as device_ctx:
        # CHECK: {{.*}}data index out of bounds{{.*}}
        test_gather_nd_oob(device_ctx)
